# import packages and modules
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
import sys
sys.path.insert(1, '/REDACTED/fairness/code/utilities/')
import UncertaintySimulation as unc
import os
import pickle
from joblib import dump, load
import yaml
import statsmodels.api as sm 
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.linear_model import LinearRegression as reg
import pdb
from pprint import pprint
from statistics import mean
import random
random.seed(50)
import math
import argparse
from tqdm import tqdm

# probabilistic estimator
def chenEstimate(dataset, pbvarb='pBlack', outcome='audited', wvarb=None):
    """Return the probabilistic (weighted-means) disparity estimate.

    Computes the mean of ``dataset[outcome]`` weighted by the probability
    column ``dataset[pbvarb]``, minus the mean of ``dataset[outcome]``
    weighted by ``1 - dataset[pbvarb]``.  When *wvarb* is given, both weights
    are additionally multiplied by the column ``dataset[wvarb]``.
    """
    prob = dataset[pbvarb]
    y = dataset[outcome]
    if wvarb is None:
        return (prob * y).sum() / prob.sum() - ((1 - prob) * y).sum() / (1 - prob).sum()

    w = dataset[wvarb]
    in_group = (prob * y * w).sum() / (prob * w).sum()
    out_group = ((1 - prob) * y * w).sum() / ((1 - prob) * w).sum()
    return in_group - out_group

# linear estimator
def regEstimate(dataset,pbvarb='pBlack',outcome='audited', wvarb=None):
    """Regress *outcome* on *pbvarb* and return ``(coefficient, standard_error)``.

    Fits OLS, or WLS weighted by the column ``dataset[wvarb]`` when *wvarb* is
    given, using HC1 heteroskedasticity-robust standard errors.  Both returned
    values refer to the *pbvarb* slope.
    """
    formula = outcome + ' ~ ' + pbvarb
    if wvarb is None:
        spec = smf.ols(formula, dataset)
    else:
        spec = smf.wls(formula, dataset, weights=dataset[wvarb])
    fitted = spec.fit(cov_type='HC1')
    return fitted.params[pbvarb], fitted.bse[pbvarb]

# define functions for calculating standard errors
def getWVar(values, weights):
    """Return the weighted population variance of *values* under *weights*.

    *values* must support elementwise subtraction (e.g. a numpy array or a
    pandas Series).
    """
    center = np.average(values, weights=weights)
    squared_dev = (values - center) ** 2
    return np.average(squared_dev, weights=weights)

def getSEMultiplier(dataset, pbvarb='pBlack', wvarb=None):
    """Return the factor that scales a regression SE into a probabilistic-estimator SE.

    The factor is ``sqrt(Var(p) / (mean(p) * (1 - mean(p))))`` with
    ``p = dataset[pbvarb]``.  When *wvarb* is given, the variance (weighted
    population variance via getWVar) and the two means are weighted by
    ``dataset[wvarb]``.  Without weights, pandas ``Series.var`` (ddof=1 by
    default) and ``Series.mean`` are used.
    """
    p = dataset[pbvarb]
    if wvarb is None:
        # NOTE(review): sample variance here vs population variance in the
        # weighted branch -- confirm the mismatch is intended.
        p_bar = p.mean()
        return np.sqrt(p.var() / (p_bar * (1 - p_bar)))

    w = dataset[wvarb]
    numerator = getWVar(p, w)
    denominator = np.average(p, weights=w) * np.average(1 - p, weights=w)
    return np.sqrt(numerator / denominator)

def getSEs(dataset, pbvarb='pBlack', outcome='audited', wvarb=None):
    """Return ``(probabilistic_se, regression_se)`` for the disparity in *outcome*.

    The regression SE is the HC1 standard error of the *pbvarb* coefficient
    from regEstimate.  The probabilistic SE is that value multiplied by
    getSEMultiplier.  The optional weight column *wvarb* is passed to both.
    """
    multiplier = getSEMultiplier(dataset, pbvarb, wvarb=wvarb)
    _coef, reg_se = regEstimate(dataset, pbvarb, outcome, wvarb=wvarb)
    prob_se = reg_se * multiplier
    return prob_se, reg_se

# --- Feature importance: load fitted models and rank their features ---
# Default directories for get_feature_importance: models are loaded from
# defaultaxpayer_id, and plots/CSVs are written to defaultout.
defaultaxpayer_id = '/REDACTED/fairness/code/rf/data/'
defaultout = '/REDACTED/'

def get_feature_names(configpath='/REDACTED/fairness/code/rf/config/',
                      datapath='/REDACTED/fairness/code/rf/data/',
                      dep_database=False):
    """Return the configured feature names that exist as columns in the RF dataset.

    Reads ``data-config.yaml`` from *configpath* and keeps, in config order,
    every listed feature that appears as a column header of the matching CSV
    under *datapath*.

    Args:
        configpath: directory (with trailing separator) holding ``data-config.yaml``.
        datapath: directory (with trailing separator) holding the cleaned CSVs.
        dep_database: if truthy, use ``clean_rf_data_plus_dep_database.csv`` with
            the ``features_plus_dep_database_str`` config list; otherwise use
            ``clean_rf_data.csv`` with ``features_str``.

    Returns:
        pandas.Index of the available feature names.
    """
    with open(configpath + 'data-config.yaml', 'r') as stream:
        # safe_load: the config only needs plain YAML types, and yaml.load()
        # without an explicit Loader raises TypeError on PyYAML >= 6.
        config = yaml.safe_load(stream)

    # Any truthy value selects the dep_database variant; previously a value
    # that was neither == True nor == False left `features` unbound.
    if dep_database:
        csv_name = 'clean_rf_data_plus_dep_database.csv'
        config_key = 'features_plus_dep_database_str'
    else:
        csv_name = 'clean_rf_data.csv'
        config_key = 'features_str'

    # Only the header row is needed to check which configured features exist.
    available = set(pd.read_csv(datapath + csv_name, nrows=0).columns)
    feature_vars = [x for x in config[config_key] if x in available]
    return pd.Index(feature_vars)

def get_feature_importance(indir=defaultaxpayer_id,
                            outdir=defaultout,
                            modelname='EITC_NCMP_RF_Class_100_plus_dep_database',
                            feature_names=None):
    """Load a saved model, plot its top-20 feature importances, and save them.

    Args:
        indir: directory (with trailing separator) containing ``<modelname>.joblib``.
        outdir: directory (with trailing separator) for the outputs.
        modelname: file stem of the saved model, also used as the plot title.
        feature_names: labels for ``model.feature_importances_``, in model order.

    Side effects:
        Writes ``<outdir><modelname>_feat_imp.png`` (bar chart of the 20
        largest importances) and ``<outdir><modelname>_importances.csv``.

    Returns:
        pandas.Series of importances indexed by *feature_names*.
    """
    # Load from *indir*; the parameter was previously ignored in favor of
    # the module-level default.
    model = load(indir + modelname + '.joblib')
    importances = pd.Series(model.feature_importances_, index=feature_names)

    # Use a dedicated figure. Without ax=, Series.plot reuses the current
    # axes, so repeated calls would stack every model's bars in one chart.
    fig, ax = plt.subplots()
    importances.nlargest(20).plot(kind='barh', title=modelname, ax=ax)
    fig.savefig(outdir + modelname + '_feat_imp.png', bbox_inches='tight')
    plt.close(fig)

    # The Series is unnamed, so the CSV header is ",0". The fold-merging
    # code in this script reads the columns 'Unnamed: 0' and '0'.
    importances.to_csv(outdir + modelname + '_importances.csv')
    return importances

# Feature-name indices for the base dataset and the dep_database-augmented dataset.
feature_names = get_feature_names(dep_database=False)
feature_names_dep_database = get_feature_names(dep_database=True)

# Compute feature importances for each of the 5 cross-validation folds of the
# refundable-credit regressor (train_set_0 .. train_set_4).
for _fold in range(5):
    get_feature_importance(
        modelname='EITC_NCMP_RF_Reg_plus_dep_database_train_set_'
                  + str(_fold)
                  + '_outcome_ref_cred_amt_dif_pv',
        feature_names=feature_names_dep_database,
    )

# Read the per-fold importance CSVs. Each has the feature name in
# 'Unnamed: 0' and the importance in column '0'. Rename the importance
# column to the fold id, so the merged frame has one column per fold.
_fold_ids = ['0', '1', '2', '3', '4']
_fold_frames = [
    pd.read_csv('/REDACTED/EITC_NCMP_RF_Reg_plus_dep_database_train_set_' + _fold_id
                + '_outcome_ref_cred_amt_dif_pv_importances.csv').rename(columns={'0': _fold_id})
    for _fold_id in _fold_ids
]

# merge together feature importances from all 5 folds, average importances together
folds = _fold_frames[0]
for _fold_frame in _fold_frames[1:]:
    folds = folds.merge(_fold_frame, how='left', on='Unnamed: 0')

# Plain sum of the fold columns, so a feature missing from any fold stays NaN.
folds['importance_avg'] = sum(folds[_fold_id] for _fold_id in _fold_ids) / len(_fold_ids)
folds = folds.rename(columns={'Unnamed: 0': 'feature'})

importances = pd.Series(folds['importance_avg'].tolist(), index=folds['feature'])

# Draw on a fresh figure. Without ax=, Series.plot reuses the current axes,
# which may still hold an earlier importance chart.
avg_fig, avg_ax = plt.subplots()
importances.nlargest(40).plot(kind='barh', title='Feature Importance (Average of 5 Folds)', ax=avg_ax)
avg_fig.savefig('/REDACTED/avg_feat_imp_test2.png', bbox_inches='tight')
plt.close(avg_fig)

# Names of the 20 features with the highest fold-averaged importance,
# ordered from most to least important.
list_feature_dep_database = importances.nlargest(20).index.tolist()

# read in main rf dataset
full_research_audits_df = pd.read_csv('/REDACTED/fairness/code/rf/data/clean_rf_data_plus_dep_database.csv')

# Calculate the association between each top feature and the BIFSG-predicted
# probability of being Black, weighted by 'base_weight'.
# NOTE(review): the point estimates come from unc.regEstimate / unc.chenEstimate,
# while the SEs come from this file's getSEs. Presumably the two modules
# implement the same estimators -- confirm.
copy = full_research_audits_df.copy(deep = True)

# The SE multiplier depends only on the probability and weight columns, not on
# the feature, so compute it once instead of once per feature.
se_mult = getSEMultiplier(copy, 'predicted_prob_black', 'base_weight')

disparity_rows = []
for feature in list_feature_dep_database:
    # NOTE(review): this SD is unweighted while the disparities are weighted -- confirm intended.
    feature_stddev = np.std(copy[feature])
    lin = unc.regEstimate(copy, 'predicted_prob_black', feature, 'base_weight')[0]
    prob = unc.chenEstimate(copy, 'predicted_prob_black', feature, 'base_weight')
    probSE, linSE = getSEs(copy, 'predicted_prob_black', feature, 'base_weight')
    disparity_rows.append([feature, feature_stddev, lin, linSE, prob, probSE, se_mult])

# Build the table in one step rather than enlarging it row by row with .loc.
feature_disp_df = pd.DataFrame(
    disparity_rows,
    columns=['Feature', 'Feature SD', 'Linear Disparity', 'Linear SE', 'Probabilistic Disparity', 'Probabilistic SE', 'SE Multiplier'],
)


# Normalize each disparity and its SE by the feature's standard deviation,
# so that effects are comparable across features measured on different scales.
_feature_sd = feature_disp_df['Feature SD']
for _raw_col, _norm_col in (
    ('Linear Disparity', 'Linear Normal'),
    ('Probabilistic Disparity', 'Probabilistic Normal'),
    ('Linear SE', 'Linear SE Normal'),
    ('Probabilistic SE', 'Probabilistic SE Normal'),
):
    feature_disp_df[_norm_col] = feature_disp_df[_raw_col] / _feature_sd

# Plot the normalized linear and probabilistic associations, with error bars
# of 1.96 x SE (a normal-approximation ~95% interval).
fig, ax = plt.subplots(figsize=(15, 10))

# Sort once, so every plotted series shares the same feature order.
sorted_disp = feature_disp_df.sort_values('Linear Normal')
x = sorted_disp['Feature']
y_l = sorted_disp['Linear Normal']
yerr_l = sorted_disp['Linear SE Normal'] * 1.96
y_p = sorted_disp['Probabilistic Normal']
yerr_p = sorted_disp['Probabilistic SE Normal'] * 1.96

ax.errorbar(x, y_l, yerr=yerr_l, fmt="o", capsize=9, label="Linear")
ax.errorbar(x, y_p, yerr=yerr_p, fmt="o", capsize=9, label="Probabilistic")
ax.set_xlabel('Top 20 Features')
ax.set_ylabel('Association with Black')
ax.axhline(y=0, color='red', linestyle='dotted')
ax.legend()

# Replace the feature names with anonymous "Feature i" labels. i is the
# position in the 'Linear Normal' sort order above, not the importance rank.
# Pin the current tick positions before relabelling, so the labels stay tied
# to those ticks (this also avoids matplotlib's FixedFormatter warning).
ticks = ax.get_xticks()
labels = ["Feature " + str(i + 1) for i in range(len(ticks))]
ax.set_xticks(ticks)
ax.set_xticklabels(labels, rotation=90)

fig.savefig('/REDACTED/feature_associations_top20_plot_normalized.png', facecolor='white', bbox_inches='tight')

plt.close(fig)